import numpy as np
from stable_baselines3 import TD3, A2C, PPO, SAC
from stable_baselines3.common.noise import NormalActionNoise
from utils import ALGOS, create_test_env
from adversary_env_wrapper import AdversaryEnvWrapper
import torch
import random
import imageio
from collections import OrderedDict
import os
from datetime import datetime
import wandb
import matplotlib.pyplot as plt
from victim_env_wrapper import VictimEnvWrapper
from utils.exp_manager import ExperimentManager
from copy import deepcopy as cp
import sys
from sb3_contrib import RecurrentPPO

# Lookup table from an algorithm identifier (e.g. ``args.adv_algo``) to the
# model class that implements it.
algos = {
    "TD3": TD3,
    "A2C": A2C,
    "PPO": PPO,
    "SAC": SAC,
    "ppo_lstm": RecurrentPPO,
}

### ADVERSARY RELATED

def get_model_fname(model_timestep, agent):
    """Return the checkpoint path under ``logs/`` for one of the two agents.

    Args:
        model_timestep: Either the string ``"latest"`` or a value convertible
            with ``int()`` giving the number of training steps.
        agent: ``"victim"`` or ``"adv"``; anything else fails the assertion.

    Returns:
        ``logs/{agent}_rl_model_latest.zip`` for ``"latest"``, otherwise
        ``logs/{agent}_rl_model_{steps}_steps.zip``.
    """
    assert agent in ["victim", "adv"]

    suffix = "latest" if model_timestep == "latest" else f"{int(model_timestep)}_steps"
    return f"logs/{agent}_rl_model_{suffix}.zip"


## GENERAL

def get_setup(args, timestep_victim=None, timestep_adv=None, run_mnp_attack=False):
    """Build the victim and adversary models and link them to each other.

    Args:
        args: Run configuration object. This function reads ``env``, ``seed``,
            ``use_wandb``, ``wandb_run_id``, ``victim_loadfrom``,
            ``victim_algo``, ``victim_noise_sigma``, ``action_scale``,
            ``illusionary_reward_weight`` and ``adv_algo``; the whole object is
            also forwarded to ``AdversaryEnvWrapper``.
        timestep_victim: If ``None``, the victim checkpoint path is chosen from
            ``args.victim_loadfrom``. Otherwise it is passed to
            ``get_model_fname(..., agent="victim")`` (``"latest"`` or a step
            count) to pick a checkpoint under ``logs/``.
        timestep_adv: If ``None``, a new adversary model is created. Otherwise
            the adversary is loaded from
            ``get_model_fname(timestep_adv, agent="adv")``.
        run_mnp_attack: Forwarded unchanged as a keyword argument of
            ``AdversaryEnvWrapper``.

    Returns:
        Tuple ``(model_vic, model_adv)``.

    Raises:
        NotImplementedError: If ``timestep_victim`` is ``None`` and
            ``args.victim_loadfrom`` is not one of ``-1``, ``"naive"``,
            ``"cotrain"`` or ``"illtrain"``.
    """

    # Hyperparameters for the adversary environment: no normalisation, and the
    # base env is wrapped in AdversaryEnvWrapper(args=..., run_mnp_attack=...).
    adv_env_hyperparams_wrapped = OrderedDict([('normalize', False), ("env_wrapper",
                                                                      OrderedDict([("adversary_env_wrapper.AdversaryEnvWrapper",
                                                                                    dict(args=args,
                                                                                         run_mnp_attack=run_mnp_attack))]))])

    # get base environment of adversary
    env_for_adv = create_test_env(
        args.env,
        n_envs=1,
        stats_path=None,
        seed=args.seed,
        log_dir=None,
        should_render=0,
        hyperparams=adv_env_hyperparams_wrapped,
        env_kwargs={},
    )

    # base_folder is only used by the "naive" checkpoint path below.
    if "Hopper" in args.env or "HalfCheetah" in args.env:
        # base_folder = "noised_animals_sac_v2"
        base_folder = "noised_animals_sac_v3"
    else:
        base_folder = "noised_agents_final_old"

    # set up victim
    tensorboard_log = None if not args.use_wandb else f"runs/{args.wandb_run_id}/vic"
    if timestep_victim is None:
        # No explicit checkpoint requested: derive the path from args.victim_loadfrom.
        if args.victim_loadfrom == -1:
            # Empty path; presumably makes ExperimentManager start from an
            # untrained victim -- TODO confirm against ExperimentManager.
            trained_agent = ""
        elif args.victim_loadfrom == "naive":
            trained_agent = f"/users/anonymous/Code_onlyremote/rl-zoo/{base_folder}/" \
                            f"{args.victim_algo}/noise_sigma_{float(args.victim_noise_sigma)}" \
                            f"_seed_{args.seed}/{args.env}/1/{args.env}.zip"
        elif args.victim_loadfrom == "cotrain":
            trained_agent = "/work/anonymous/illusionary/cotrain_ole_v1/"

            # The run-directory name encodes the hyperparameters of the run that
            # produced the checkpoint; the layout differs between the
            # HalfCheetah/Hopper runs and all other environments.
            if "HalfCheetah" in args.env or "Hopper" in args.env:
                trained_agent += f"{args.env}/action_scale={args.action_scale},append_l2=True,env={args.env}," \
                                f"evaluate=False,experiment_id_new=cotrain,illusionary_reward_weight=0," \
                                f"n_adv_steps_per_vic_step=1,seed={args.seed},victim_algo=ppo_lstm,wandb_tag=resub_runs_v1"
            else:
                # NOTE(review): action_scale is hard-coded to 0.2 here instead of
                # using args.action_scale -- confirm this is intended.
                trained_agent += f"{args.env}/action_scale=0.2,env={args.env}," \
                                    f"evaluate=False,experiment_id_new=cotrain,illusionary_reward_weight=0," \
                                    f"n_adv_steps_per_vic_step=1,seed={args.seed},victim_algo=ppo_lstm,victim_loadfrom=-1,wandb_tag=expfinaleole"
                
            trained_agent += "/logs/victim_rl_model_latest.zip"
        elif args.victim_loadfrom == "illtrain":
            # A weight of 0 falls back to loading the checkpoint trained with weight 10.
            illu_weight_to_load = args.illusionary_reward_weight if args.illusionary_reward_weight != 0 else 10
            trained_agent = "/work/anonymous/illusionary/illtrain_ole_v1/"

            # NOTE(review): both paths below contain "action_scale==" (double
            # '='), unlike the "action_scale=" used in the cotrain paths above.
            # Verify against the directory names on disk before changing.
            if "HalfCheetah" in args.env or "Hopper" in args.env:
                trained_agent += f"{args.env}/action_scale=={args.action_scale},adv_append_wm=True,append_l2=True,env={args.env}," \
                                f"evaluate=False,experiment_id_new=illtrain,illusionary_reward_weight={illu_weight_to_load}," \
                                f"n_adv_steps_per_vic_step=1,seed={args.seed},victim_algo=ppo_lstm,wandb_tag=resub_runs_v1"

            else:
                trained_agent += f"{args.env}/action_scale=={args.action_scale},adv_append_wm=True,env={args.env}," \
                             f"evaluate=False,experiment_id_new=illtrain,illusionary_reward_weight={illu_weight_to_load}," \
                             f"n_adv_steps_per_vic_step=1,seed={args.seed},victim_algo=ppo_lstm,victim_loadfrom=-1,wandb_tag=expfinaleole"
            
            trained_agent += "/logs/victim_rl_model_latest.zip"
        else:
            raise NotImplementedError
    else:
        # Explicit checkpoint requested: use the local logs/ folder.
        trained_agent = get_model_fname(timestep_victim, agent="victim")

    # The victim's environment is wrapped in VictimEnvWrapper by the ExperimentManager.
    vic_env_hyperparams_wrapped = {"env_wrapper": "victim_env_wrapper.VictimEnvWrapper"}
    exp_manager = ExperimentManager(
        args=None,
        algo=args.victim_algo,
        env_id=args.env,
        log_folder="logs",
        hyperparams=vic_env_hyperparams_wrapped,
        tensorboard_log=tensorboard_log,
        trained_agent=trained_agent) # set trained agent to pretrained agent if you wish to load sth here
    # Only the first element of the returned pair (the model) is used.
    model_vic, _ = exp_manager.setup_experiment()

    # set up adversary
    tensorboard_log = None if not args.use_wandb else f"runs/{args.wandb_run_id}/adv"
    if timestep_adv is None:
        model_adv  = algos[args.adv_algo]("MlpPolicy", env_for_adv, verbose=1, seed=args.seed, tensorboard_log=tensorboard_log)
    else:
        # NOTE(review): the loaded model gets its env by direct attribute
        # assignment, and tensorboard_log/seed are not passed on this path.
        model_adv  = algos[args.adv_algo].load(get_model_fname(timestep_adv, agent="adv"))
        model_adv.env = env_for_adv

    # initialize correct connections
    # TODO  do this nicer / correctly
    # Every victim env gets a reference to the adversary model, and the (single)
    # adversary env gets a reference to the victim model.
    for _env in model_vic.env.envs:
        _env.model_adv = model_adv
    model_adv.env.envs[0].victim_agent_model = model_vic

    return model_vic, model_adv


def process_recorded_frames(frames_right, frames_left, fps, args, log_title="final"):
    """Split two parallel frame recordings into blocks and export each as a GIF.

    An all-zero frame in ``frames_left`` marks the end of a block; frames after
    the last such marker are not exported. For every block the left and right
    frames are placed side by side, the first ``segment_length`` combined
    frames are written as a GIF, and, if ``args.use_wandb`` is set, a
    downsampled copy of the same segment is logged to wandb.

    Args:
        frames_right: Frames shown on the right-hand side of the combined clip.
        frames_left: Frames shown on the left-hand side; also used to locate
            the block boundaries.
        fps: Frame rate for the GIF and the wandb video.
        args: Run configuration; reads ``seed``, ``experiment_id_new``, ``env``,
            ``action_scale``, ``illusionary_reward_weight`` and ``use_wandb``.
        log_title: Prefix of the wandb log key.
    """
    timestamp = datetime.now().strftime("%Y_%m_%d-%I:%M:%S_%p")
    out_dir = os.path.join(
        "/work/anonymous/gifs_out_animals_v3",
        str(args.experiment_id_new),
        args.env,
        f"scale_{args.action_scale}",
        f"ill_rweight_{args.illusionary_reward_weight}",
        timestamp + "_seed_" + str(args.seed),
    )
    os.makedirs(out_dir, exist_ok=True)

    downsample_step = 2
    if "CartPole" in args.env:
        segment_length = 10
    else:
        segment_length = 500

    # Indices of the all-zero separator frames, preceded by a virtual
    # separator at -1 so that the first block starts at frame 0.
    is_blank = np.array([np.sum(frame) == 0 for frame in frames_left])
    boundaries = [-1] + list(np.where(is_blank == 1)[0])

    for index, (prev_cut, next_cut) in enumerate(zip(boundaries[:-1], boundaries[1:])):
        block = slice(prev_cut + 1, next_cut)
        left_block = frames_left[block]
        right_block = frames_right[block]
        combined = combine_frame_sets(left_block, right_block)

        # Only the leading segment of the combined view is written to disk.
        # A fresh slice is passed because save_frames_as_gif appends to its input.
        save_frames_as_gif(out_dir, combined[:segment_length], fps, f"block_{index}_segment_combined")

        if args.use_wandb:
            small_frames = downsample(combined[:segment_length], downsample_step)
            # Stack to one array and move axis 3 to position 1 for wandb.Video.
            video = np.moveaxis(np.stack(small_frames, axis=0), 3, 1)
            wandb.log({f"{log_title}_block_{index}": wandb.Video(video, fps=fps, format="gif")})

        print(f"done with block {index}")


def downsample(frames, dsp_factor):
    """Return a list with every frame subsampled along its first two axes.

    Each frame keeps every ``dsp_factor``-th row and column; the third axis is
    left untouched.
    """
    stride = slice(None, None, dsp_factor)
    reduced = []
    for frame in frames:
        reduced.append(frame[stride, stride, :])
    return reduced

def combine_frame_sets(frames_a, frames_b):
    """Place two equally long frame sequences side by side, frame by frame.

    Each output frame is ``frames_a[k]``, a 3-pixel-wide all-zero separator,
    and ``frames_b[k]`` concatenated along axis 1 (width).

    Args:
        frames_a: Sequence of arrays shown on the left.
        frames_b: Sequence of arrays shown on the right; must have the same
            length as ``frames_a``.

    Returns:
        List of concatenated frames; an empty list if the inputs are empty.

    Raises:
        AssertionError: If the two sequences differ in length.
    """
    assert len(frames_a) == len(frames_b)

    # Nothing to combine; without this guard frames_a[0] would raise IndexError.
    if len(frames_a) == 0:
        return []

    first = frames_a[0]
    # The separator takes its height and any trailing (channel) dimensions from
    # the first frame, so grayscale (H, W) and RGBA (H, W, 4) frames work as
    # well as the usual (H, W, 3) ones.
    separator = np.zeros((first.shape[0], 3) + tuple(first.shape[2:]), dtype=np.uint8)

    return [
        np.concatenate((frame_a, separator, frame_b), axis=1)
        for frame_a, frame_b in zip(frames_a, frames_b)
    ]

def save_frames_as_gif(folder, frames_in, fps, identifier):
    """Write a frame sequence, followed by one all-zero frame, as a GIF.

    The file is stored as ``<folder>/<timestamp>_<identifier>.gif``.

    Args:
        folder: Existing directory to write into.
        frames_in: Non-empty sequence of frames; it is not modified.
        fps: Frame rate passed to ``imageio.mimsave``.
        identifier: Suffix of the file name (without extension).

    Raises:
        IndexError: If ``frames_in`` is empty (no frame to take the shape from).
    """
    time_str = datetime.now().strftime("%Y_%m_%d-%I:%M:%S_%p")
    output_path = os.path.join(folder, f'{time_str}_{identifier}.gif')

    # Trailing blank frame with the shape of the first frame.
    blank_frame = np.zeros(frames_in[0].shape, dtype=np.uint8)
    # Build a new list instead of appending to frames_in, so the caller's
    # sequence does not grow by one frame on every call.
    frames_out = list(frames_in) + [blank_frame]

    imageio.mimsave(output_path, frames_out, format="GIF-PIL", fps=fps, loop=1)

def seed_everything(seed):
    """Seed the global RNGs of ``random``, NumPy and PyTorch with ``seed``.

    Calls, in order, ``random.seed``, ``np.random.seed``, ``torch.manual_seed``
    and ``torch.cuda.manual_seed``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seeder in seeders:
        seeder(seed)

def save_noise_fig(all_metrics):
    """Plot mean victim reward against noise sigma for every attack and save it.

    One line (mean) with a shaded band (mean +/- one standard deviation) is
    drawn per attack on a logarithmic x axis. The figure is written to
    ``noise_figure.png`` in the current working directory.

    Args:
        all_metrics: Mapping ``attack -> noise -> metrics`` where ``metrics``
            has a ``"victim_rewards"`` entry holding the rewards to aggregate.

    Returns:
        Tuple ``(fig, fpath)`` with the matplotlib figure and the path of the
        saved image. The figure is left open for the caller.
    """
    ALL_COLORS = ["red", "grey", "blue"]
    ALL_STYLES = ["solid", "dotted", "dashed", "dashdot"]

    fig, ax = plt.subplots(figsize=(10, 6))

    for k, (attack, metrics_per_noise) in enumerate(all_metrics.items()):
        # Take x from the same dict that provides y, so the two always line up
        # even if attacks were evaluated on different noise levels.
        noise_levels = list(metrics_per_noise.keys())
        rewards_per_noise = [metrics_per_noise[noise]["victim_rewards"] for noise in noise_levels]

        y = np.array([np.mean(rewards) for rewards in rewards_per_noise])
        y_std = np.array([np.std(rewards) for rewards in rewards_per_noise])

        # Cycle through the palettes so that more attacks than colours/styles
        # do not raise IndexError.
        color = ALL_COLORS[k % len(ALL_COLORS)]
        style = ALL_STYLES[k % len(ALL_STYLES)]

        ax.plot(noise_levels, y, color=color, label=attack, linestyle=style, linewidth=2, markersize=10)
        ax.fill_between(noise_levels, y - y_std, y + y_std, facecolor=color, alpha=0.1)

    # Configure the axes object directly instead of relying on pyplot's
    # "current axes" global state.
    ax.set_xscale("log")
    ax.set_xlabel("Noise sigma used for noising of victim observation")
    ax.set_ylabel("Victim reward")
    ax.set_title("Analysis of effectiveness of randomized smoothing for different types of attacks")
    ax.legend()

    fpath = "noise_figure.png"
    fig.savefig(fpath)

    return fig, fpath

